import os

import joblib
import numpy as np
import torch


class BaseCallbacks:
    """
    Shared state and helpers for callbacks that monitor a single training value.

    Subclasses are expected to fill in ``_last_val`` (the reference value to compare against),
    ``_use_train`` (train vs. validation data) and ``_use_loss`` (loss vs. metric).
    """

    def __init__(self):
        # All three stay None until a subclass assigns real values.
        self._last_val = None
        self._use_train = None
        self._use_loss = None

    def _compare_decrease(self, new_val: float):
        """
        :param new_val: freshly observed value
        :return: True if new_val is strictly smaller than the stored reference value, False otherwise
        """
        return bool(new_val < self._last_val)

    def _compare_increase(self, new_val: float):
        """
        :param new_val: freshly observed value
        :return: True if new_val is strictly larger than the stored reference value, False otherwise
        """
        return bool(new_val > self._last_val)

    def _set_new_val(self, train_loss: float, val_loss: float, train_metric: float, val_metric: float):
        """
        Pick the one value (out of the four supplied) that this callback monitors.

        :param train_loss: loss on the training data
        :param val_loss: loss on the validation data
        :param train_metric: main metric on the training data
        :param val_metric: main metric on the validation data
        :return: the train or validation value (per ``_use_train``), loss or metric (per ``_use_loss``)
        """
        # First narrow down to the data split, then to loss vs. metric.
        if self._use_train:
            loss, metric = train_loss, train_metric
        else:
            loss, metric = val_loss, val_metric

        return loss if self._use_loss else metric


class EarlyStoppingCallback(BaseCallbacks):
    """
    Tells the caller to stop training once the monitored value has failed to improve
    for ``patience`` consecutive calls.
    """

    def __init__(self, patience: int, use_train: bool = True, use_loss: bool = True, should_decrease: bool = True):
        """
        :param patience: how long to wait without any improvement to finish training
        :param use_train: True if should use train data for deciding, False for using validation data
        :param use_loss: True for using loss value for deciding on a stop, False for using the training main metric
        :param should_decrease: True if the value used for stopping should decrease with time, false if should increase
        """
        super().__init__()
        self._use_train = use_train
        self._use_loss = use_loss
        self._original_patience = patience
        self._patience = patience
        self._should_decrease = should_decrease
        # Start from the worst possible value so the first finite observation counts as an improvement.
        self._last_val = np.inf if should_decrease else -np.inf
        self._compare_func = self._compare_decrease if should_decrease else self._compare_increase

    def __call__(self, train_loss: float, val_loss: float, train_metric: float, val_metric: float):
        """
        Record one round of results and report whether training should stop.

        :param train_loss: loss on the training data
        :param val_loss: loss on the validation data
        :param train_metric: main metric on the training data
        :param val_metric: main metric on the validation data
        :return: True if the remaining patience is exhausted, False otherwise
        """
        new_val = self._set_new_val(train_loss, val_loss, train_metric, val_metric)

        if self._compare_func(new_val):
            # Strict improvement: remember the new best value and restore the full patience budget.
            self._last_val = new_val
            self._patience = self._original_patience
        else:
            self._patience -= 1

        # "<= 0" rather than "== 0": if the callback keeps being called after the budget ran out
        # (or was created with a non-positive patience) the counter goes negative, and the stop
        # signal must stay raised instead of silently switching off again.
        return self._patience <= 0


class ModelCheckpointCallback(BaseCallbacks):
    """
    Saves the attached model's ``state_dict`` whenever the monitored value improves.

    The model must be attached through the ``model`` property before the callback is called.
    """

    def __init__(self, save_folder: str, use_train: bool = True, use_loss: bool = True, save_only_best: bool = True,
                 should_decrease: bool = True):
        """
        Best state model checkpoint
        :param save_folder: Full path to model save folder
        :param use_train: Should use train data for checkpoint calculation
        :param use_loss: True for using loss value for deciding on a save, False for using the training main metric
        :param save_only_best: if True every save overwrites the single 'model_state_cp' folder, if False each save
            goes to its own numbered 'model_state_cp_<n>' folder. In both cases a save only happens on improvement
        :param should_decrease: True if the value used for saving should decrease with time, false if should increase
        """
        super().__init__()
        self._save_folder = save_folder
        self._use_train = use_train
        self._use_loss = use_loss
        # Start from the worst possible value so the first finite observation counts as an improvement.
        self._last_val = np.inf if should_decrease else -np.inf
        self._compare_func = self._compare_decrease if should_decrease else self._compare_increase
        # Running index used for the numbered folder names; incremented after every save.
        self._save_count = 1
        self._save_only_best = save_only_best
        self._model = None
        self._should_decrease = should_decrease

    @property
    def model(self):
        """The model whose state is checkpointed (None until one is attached)."""
        return self._model

    @model.setter
    def model(self, val):
        self._model = val

    def __call__(self, train_loss, val_loss, train_metric, val_metric, *args, **kwargs):
        """
        Save the model state if the monitored value improved over the best one seen so far.

        :param train_loss: loss on the training data
        :param val_loss: loss on the validation data
        :param train_metric: main metric on the training data
        :param val_metric: main metric on the validation data
        :param args: ignored
        :param kwargs: ignored
        :return: True if a checkpoint was written, False otherwise
        """
        new_val = self._set_new_val(train_loss, val_loss, train_metric, val_metric)
        should_save = self._compare_func(new_val)

        if should_save:
            self._last_val = new_val
            if self._save_only_best:
                dir_name = 'model_state_cp'
            else:
                dir_name = f'model_state_cp_{self._save_count}'
            full_save_path = os.path.join(self._save_folder, dir_name)

            # exist_ok=True already tolerates an existing directory, so no separate
            # (race-prone) os.path.exists probe is needed beforehand.
            os.makedirs(full_save_path, exist_ok=True)
            torch.save(self._model.state_dict(), os.path.join(full_save_path, 'model.pt'))
            self._save_count += 1
        return should_save
